In [61]:
Copied!
! lamin load scprint
! lamin load scprint
Exception ignored in: <function _releaseLock at 0x7ff2bfa68310>
Traceback (most recent call last):
File "/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/logging/__init__.py", line 228, in _releaseLock
def _releaseLock():
KeyboardInterrupt:
💡 found cached instance metadata: /home/ml4ig1/.lamin/instance--jkobject--scprint.env 💡 loaded instance: jkobject/scprint
In [5]:
Copied!
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging, EarlyStopping, LearningRateMonitor, LearningRateFinder
seed_everything(42, workers=True)
from scprint import scPrint
from scprint.trainer import TrainingMode
from scdataloader import DataModule
import pandas as pd
from scdataloader.utils import load_genes
import torch
torch.set_float32_matmul_precision('medium')
%load_ext autoreload
%autoreload 2
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint, StochasticWeightAveraging, EarlyStopping, LearningRateMonitor, LearningRateFinder
seed_everything(42, workers=True)
from scprint import scPrint
from scprint.trainer import TrainingMode
from scdataloader import DataModule
import pandas as pd
from scdataloader.utils import load_genes
import torch
torch.set_float32_matmul_precision('medium')
%load_ext autoreload
%autoreload 2
Global seed set to 42
In [6]:
Copied!
# Label configuration shared by the data module and the model.

# TODO: drop tissue & dev stage until part or is taken into account
# Labels passed to DataModule as `hierarchical_labels`.
hierarchical_labels = [
    "cell_type_ontology_term_id",  # 1
    # "tissue_ontology_term_id",
    "disease_ontology_term_id",  # 2
    # "development_stage_ontology_term_id",
    "assay_ontology_term_id",  # 3
    'self_reported_ethnicity_ontology_term_id',  # 4
]

# Labels passed to DataModule as both `label_to_pred` and `label_to_weight`.
labels_to_pred = hierarchical_labels + [
    'sex_ontology_term_id',  # 5
    "organism_ontology_term_id",  # 6
]

# Every label passed to DataModule as `all_labels`.
all_labels = labels_to_pred + [
    #'dataset_id',
    'cell_culture',
    "heat_diff",
    "total_counts",
    "nnz",
    "dpt_group",
]

# Path to the gene embeddings parquet file.
gene_emb = '../data/temp/embeddings.parquet'
# Model width; also the size the gene embeddings are pooled down to.
d_model = 256
In [7]:
Copied!
# Build the data module over the "preprocessed dataset" collection (human only)
# and run its setup once.
datamodule = DataModule(
    collection_name="preprocessed dataset",
    gene_embeddings=gene_emb,
    all_labels=all_labels,
    hierarchical_labels=hierarchical_labels,
    organisms=["NCBITaxon:9606"],
    how="random expr",
    max_len=1200,
    add_zero_genes=0,
    # how much more you will see the most present vs less present category
    weight_scaler=10,
    label_to_weight=labels_to_pred,
    label_to_pred=labels_to_pred,
    batch_size=64,
    num_workers=16,
    train_oversampling=2,
    validation_split=0.05,
    do_gene_pos='../data/main/biomart.parquet',
    test_split=0.05,
)
# setup() prints the dataset summary and the files it "will consider test datasets";
# presumably the returned value is that list of files -- confirm in scdataloader.
testfiles = datamodule.setup()
won't do any check but we recommend to have your dataset coming from local storage
83.14606741573034% are aligned
total dataset size is 97.032938749 Gb
---
dataset contains:
4926521 cells
70116 genes
11 labels
6 clss_to_pred
4 hierarchical_clss
4 join_vars
1 organisms
dataset contains 229 classes to predict
seeing a string: loading gene positions as biomart parquet file
these files will be considered test datasets:
/home/ml4ig1/scprint/.lamindb/BljRloq1xjcxRNDpejzI.h5ad
/home/ml4ig1/scprint/.lamindb/yBCKp6HmXuHa0cZptMo7.h5ad
perc test: 0.0057480725242011555
In [ ]:
Copied!
# check the geneposition thing
# check the geneposition thing
In [4]:
Copied!
# Load the precomputed gene embeddings and keep the rows for the datamodule's genes,
# in the datamodule's gene order.
all_embeddings = pd.read_parquet(gene_emb)
# `.loc[list_of_labels]` raises KeyError as soon as one label is missing, which made
# the "no genes" / "subset of genes" checks below unreachable. Select only the genes
# that are actually present so those checks can do their job.
present_genes = [gene for gene in datamodule.genes if gene in all_embeddings.index]
embeddings = all_embeddings.loc[present_genes]
if len(embeddings) == 0:
    raise ValueError(
        f"the gene embeddings file {gene_emb} does not contain any of the genes given to the model"
    )
elif len(embeddings) < len(datamodule.genes):
    # NOTE(review): later cells index by `datamodule.genes` and assume one row per
    # gene; they need adapting if this warning ever fires.
    print(
        "Warning: only a subset of the genes available in the embeddings file."
    )
print("number of genes: ", len(embeddings))
# Average-pool each gene's embedding vector down to `d_model` values.
sembeddings = torch.nn.AdaptiveAvgPool1d(d_model)(
    torch.tensor(embeddings.values)
)
In [57]:
Copied!
from anndata import AnnData
import scanpy as sc
import numpy as np

# Wrap the pooled gene embeddings in an AnnData object
# (the recorded output shows 33890 observations x 256 variables).
adata = AnnData(sembeddings.detach().numpy())
# The neighbors graph is left disabled in this cell.
#sc.pp.neighbors(adata)
adata
Out[57]:
AnnData object with n_obs × n_vars = 33890 × 256
In [58]:
Copied!
# Attach the gene annotation table as per-observation metadata, with rows taken
# in the same order as `datamodule.genes`.
df = load_genes()
adata.obs = df.loc[datamodule.genes]
adata
Out[58]:
AnnData object with n_obs × n_vars = 33890 × 256
obs: 'uid', 'symbol', 'stable_id', 'ncbi_gene_ids', 'biotype', 'description', 'synonyms', 'organism_id', 'public_source_id', 'created_at', 'updated_at', 'created_by_id', 'mt', 'ribo', 'hb', 'organism'
In [59]:
Copied!
# Cast gene symbols to plain strings so the `.str` accessor can be used on them.
adata.obs['symbol'] = adata.obs.symbol.astype(str)
In [60]:
Copied!
# Select the genes whose symbol starts with one of these prefixes, sorted by symbol.
# `str.match` anchors at the start of the string, so the alternation below is
# equivalent to OR-ing one `startswith` test per prefix.
symbol_prefixes = ('MEF2', 'IFNA', 'ZEB', "CHEK")
has_prefix = adata.obs.symbol.str.match("|".join(symbol_prefixes))
loc = adata.obs[has_prefix].sort_values(by="symbol").reset_index()
loc
Out[60]:
| ensembl_gene_id | uid | symbol | stable_id | ncbi_gene_ids | biotype | description | synonyms | organism_id | public_source_id | created_at | updated_at | created_by_id | mt | ribo | hb | organism | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | ENSG00000149554 | 4Ortzt9Vugii | CHEK1 | None | 1111 | protein_coding | checkpoint kinase 1 [Source:HGNC Symbol;Acc:HG... | CHK1 | 2 | 9.0 | 2023-11-22 13:16:56.958347+00:00 | 2023-11-22 13:16:56.958354+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 1 | ENSG00000183765 | 7VqzGOaa4FM9 | CHEK2 | None | 11200 | protein_coding | checkpoint kinase 2 [Source:HGNC Symbol;Acc:HG... | CDS1|RAD53|BA444G7|CHK2|HUCDS1|PP1425 | 2 | 9.0 | 2023-11-22 13:16:57.275899+00:00 | 2023-11-22 13:16:57.275906+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 2 | ENSG00000197919 | QYLNkkiyikzH | IFNA1 | None | 3439 | protein_coding | interferon alpha 1 [Source:HGNC Symbol;Acc:HGN... | IFNA@|IFN|IFN-ALPHAD|IFN-ALPHA|IFL|IFNA13 | 2 | 9.0 | 2023-11-22 13:16:57.375774+00:00 | 2023-11-22 13:16:57.375780+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 3 | ENSG00000186803 | 1CFWiKSyXuiA | IFNA10 | None | 3446 | protein_coding | interferon alpha 10 [Source:HGNC Symbol;Acc:HG... | IFN-ALPHAC | 2 | 9.0 | 2023-11-22 13:16:57.311906+00:00 | 2023-11-22 13:16:57.311913+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 4 | ENSG00000233816 | t4crQ8ytw0iA | IFNA13 | None | 3447 | protein_coding | interferon alpha 13 [Source:HGNC Symbol;Acc:HG... | 2 | 9.0 | 2023-11-22 13:16:58.491687+00:00 | 2023-11-22 13:16:58.491694+00:00 | 1 | False | False | False | NCBITaxon:9606 | |
| 5 | ENSG00000228083 | 70yUMP3PRc9i | IFNA14 | None | 3448 | protein_coding | interferon alpha 14 [Source:HGNC Symbol;Acc:HG... | IFN-ALPHAH|LEIF2H | 2 | 9.0 | 2023-11-22 13:16:58.264176+00:00 | 2023-11-22 13:16:58.264183+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 6 | ENSG00000147885 | 6TLO49f0ILni | IFNA16 | None | 3449 | protein_coding | interferon alpha 16 [Source:HGNC Symbol;Acc:HG... | IFN-ALPHAO | 2 | 9.0 | 2023-11-22 13:16:56.948043+00:00 | 2023-11-22 13:16:56.948050+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 7 | ENSG00000234829 | 4dmwtJjZ0k6L | IFNA17 | None | 3451 | protein_coding | interferon alpha 17 [Source:HGNC Symbol;Acc:HG... | IFN-ALPHAI|LEIF2C1 | 2 | 9.0 | 2023-11-22 13:16:58.533573+00:00 | 2023-11-22 13:16:58.533580+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 8 | ENSG00000188379 | 6p1EyXksTdgB | IFNA2 | None | 3440 | protein_coding | interferon alpha 2 [Source:HGNC Symbol;Acc:HGN... | IFN-ALPHAA|IFNA | 2 | 9.0 | 2023-11-22 13:16:57.330804+00:00 | 2023-11-22 13:16:57.330810+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 9 | ENSG00000137080 | 4L5Ab7BN5Htd | IFNA21 | None | 3452 | protein_coding | interferon alpha 21 [Source:HGNC Symbol;Acc:HG... | IFN-ALPHAI | 2 | 9.0 | 2023-11-22 13:16:56.868721+00:00 | 2023-11-22 13:16:56.868728+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 10 | ENSG00000236637 | 4V6HGby0lycz | IFNA4 | None | 3441 | protein_coding | interferon alpha 4 [Source:HGNC Symbol;Acc:HGN... | MGC142200|IFN-ALPHA4A | 2 | 9.0 | 2023-11-22 13:16:58.605795+00:00 | 2023-11-22 13:16:58.605801+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 11 | ENSG00000147873 | 5AkZAtlQZxKP | IFNA5 | None | 3442 | protein_coding | interferon alpha 5 [Source:HGNC Symbol;Acc:HGN... | IFN-ALPHAG | 2 | 9.0 | 2023-11-22 13:16:56.947962+00:00 | 2023-11-22 13:16:56.947968+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 12 | ENSG00000120235 | 2CUBXjQb0mXr | IFNA6 | None | 3443 | protein_coding | interferon alpha 6 [Source:HGNC Symbol;Acc:HGN... | IFN-ALPHAK | 2 | 9.0 | 2023-11-22 13:16:56.745805+00:00 | 2023-11-22 13:16:56.745815+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 13 | ENSG00000214042 | 6BkFRyl7zNqq | IFNA7 | None | 3444 | protein_coding | interferon alpha 7 [Source:HGNC Symbol;Acc:HGN... | IFNA-J|IFN-ALPHAJ | 2 | 9.0 | 2023-11-22 13:16:57.976149+00:00 | 2023-11-22 13:16:57.976156+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 14 | ENSG00000120242 | 6fPid6Uqo6p5 | IFNA8 | None | 3445 | protein_coding | interferon alpha 8 [Source:HGNC Symbol;Acc:HGN... | IFN-ALPHAB | 2 | 9.0 | 2023-11-22 13:16:56.745846+00:00 | 2023-11-22 13:16:56.745856+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 15 | ENSG00000142166 | jSeUOkMz0Jiw | IFNAR1 | None | 3454 | protein_coding | interferon alpha and beta receptor subunit 1 [... | IFNAR|IFRC | 2 | 9.0 | 2023-11-22 13:16:56.908455+00:00 | 2023-11-22 13:16:56.908461+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 16 | ENSG00000159110 | 3hbb31iADZoQ | IFNAR2 | None | 3455 | protein_coding | interferon alpha and beta receptor subunit 2 [... | IFNABR | 2 | 9.0 | 2023-11-22 13:16:57.011051+00:00 | 2023-11-22 13:16:57.011058+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 17 | ENSG00000249624 | 5d9SmuIAUyCH | IFNAR2-IL10RB | None | 127882475 | protein_coding | IFNAR2-IL10RB readthrough [Source:NCBI gene (f... | 2 | 9.0 | 2023-11-22 13:16:58.872178+00:00 | 2023-11-22 13:16:58.872185+00:00 | 1 | False | False | False | NCBITaxon:9606 | |
| 18 | ENSG00000068305 | 5UhPj3IigoJb | MEF2A | None | 4205 | protein_coding | myocyte enhancer factor 2A [Source:HGNC Symbol... | RSRFC9|RSRFC4 | 2 | 9.0 | 2023-11-22 13:16:56.456236+00:00 | 2023-11-22 13:16:56.456243+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 19 | ENSG00000213999 | 2Zc2QwPmeQPj | MEF2B | None | 4207|100271849 | protein_coding | myocyte enhancer factor 2B [Source:HGNC Symbol... | RSRFR2 | 2 | 9.0 | 2023-11-22 13:16:57.973860+00:00 | 2023-11-22 13:16:57.973867+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 20 | ENSG00000081189 | 1VY5MChVVkz5 | MEF2C | None | 4208 | protein_coding | myocyte enhancer factor 2C [Source:HGNC Symbol... | 2 | 9.0 | 2023-11-22 13:16:56.480805+00:00 | 2023-11-22 13:16:56.480812+00:00 | 1 | False | False | False | NCBITaxon:9606 | |
| 21 | ENSG00000245864 | 559YoYs1XTaY | MEF2C-AS2 | None | 109729137 | lncRNA | MEF2C antisense RNA 2 [Source:HGNC Symbol;Acc:... | CTC-467M3.1 | 2 | 9.0 | 2023-11-22 13:16:58.813285+00:00 | 2023-11-22 13:16:58.813292+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 22 | ENSG00000116604 | 5RJ75hd8BnoW | MEF2D | None | 4209 | protein_coding | myocyte enhancer factor 2D [Source:HGNC Symbol... | 2 | 9.0 | 2023-11-22 13:16:56.708287+00:00 | 2023-11-22 13:16:56.708302+00:00 | 1 | False | False | False | NCBITaxon:9606 | |
| 23 | ENSG00000291761 | 57dzZziVJBoq | MEF2D | None | 4209 | protein_coding | myocyte enhancer factor 2D [Source:HGNC Symbol... | 2 | 9.0 | 2023-11-22 13:17:00.847707+00:00 | 2023-11-22 13:17:00.847714+00:00 | 1 | False | False | False | NCBITaxon:9606 | |
| 24 | ENSG00000148516 | 7JJCpKRlbelH | ZEB1 | None | 6935 | protein_coding | zinc finger E-box binding homeobox 1 [Source:H... | NIL-2-A|FECD6|AREB6|ZEB|PPCD3|ZFHX1A|BZP|ZFHEP... | 2 | 9.0 | 2023-11-22 13:16:56.951716+00:00 | 2023-11-22 13:16:56.951723+00:00 | 1 | False | False | False | NCBITaxon:9606 |
| 25 | ENSG00000237036 | 1MvXWONs6EFf | ZEB1-AS1 | None | 220930 | lncRNA | ZEB1 antisense RNA 1 [Source:HGNC Symbol;Acc:H... | 2 | 9.0 | 2023-11-22 13:16:58.623456+00:00 | 2023-11-22 13:16:58.623463+00:00 | 1 | False | False | False | NCBITaxon:9606 | |
| 26 | ENSG00000169554 | 6A2X62YmiKpG | ZEB2 | None | 9839 | protein_coding | zinc finger E-box binding homeobox 2 [Source:H... | SIP-1|SIP1|ZFHX1B|KIAA0569 | 2 | 9.0 | 2023-11-22 13:16:57.127062+00:00 | 2023-11-22 13:16:57.127069+00:00 | 1 | False | False | False | NCBITaxon:9606 |
In [62]:
Copied!
import matplotlib.pyplot as plt

# Pairwise correlation between the embedding rows of the selected genes.
correlation_matrix = np.corrcoef(adata[loc['ensembl_gene_id']].X)
tick_positions = np.arange(len(loc['symbol']))

# Draw the heatmap once, with gene symbols on both axes.
fig, ax = plt.subplots(figsize=(10, 10))
heatmap = ax.imshow(correlation_matrix, cmap='hot', interpolation='nearest', vmin=0.5)
fig.colorbar(heatmap, ax=ax)
ax.set_xticks(tick_positions)
ax.set_xticklabels(loc['symbol'], rotation=90)
ax.set_yticks(tick_positions)
ax.set_yticklabels(loc['symbol'])
ax.set_title('Correlation Matrix')
plt.show()
In [ ]:
Copied!
# Notes on gene groups (kept as comments so this cell does not raise NameError):
# MEF2s
# ZEB1, ZEB2
# IFN1
# IFN2
In [7]:
Copied!
# sc.tl.umap needs a neighbors graph stored in `adata`; the only sc.pp.neighbors
# call in this notebook is commented out, so compute the graph here when it is missing.
if "neighbors" not in adata.uns:
    # use_rep="X": build the graph directly on the 256-d embedding matrix.
    sc.pp.neighbors(adata, use_rep="X")
sc.tl.umap(adata)
In [9]:
Copied!
# Plot the gene UMAP once, coloured by gene biotype.
sc.pl.umap(adata, color=['biotype'])
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter(
In [ ]:
Copied!
# make a GRN prediction task
# test on my benGRN
# make a GRN prediction task
# test on my benGRN
In [ ]:
Copied!
# switch to using the new laminDB mappedCollection
# switch to using the new laminDB mappedCollection
In [ ]:
Copied!
# follow the path of the data through the model and the transformers
# follow the loss computations
# follow what happens to the uncalled functions of optimizations etc.. how the lr gets updated (it never gets updated on my end)
# follow what happens to the gradients scaling (weight) and some other values
# test DAB loss
# test with regular transformer and make it work with current weights
# look at why we can't visualize the model graph
# follow the path of the data through the model and the transformers
# follow the loss computations
# follow what happens to the uncalled functions of optimizations etc.. how the lr gets updated (it never gets updated on my end)
# follow what happens to the gradients scaling (weight) and some other values
# test DAB loss
# test with regular transformer and make it work with current weights
# look at why we can't visualize the model graph
In [ ]:
Copied!
# look at why we can't see the parameters in wandb?
# look at plotting the params of the dataloader too
# impact of gradient clipping
# what happens when distributed random sampling in case of ddp??
# find a way to relaunch job after a slurm sigterm
# look at why we can't see the parameters in wandb?
# look at plotting the params of the dataloader too
# impact of gradient clipping
# what happens when distributed random sampling in case of ddp??
# find a way to relaunch job after a slurm sigterm
In [ ]:
Copied!
# import scGPT's weights
# import scGPT's weights
In [8]:
Copied!
# Instantiate the scPrint model once, wired to the data module's genes, labels,
# class hierarchy, gene positions and label decoders.
model = scPrint(
    genes=datamodule.genes,
    d_model=d_model,
    nhead=4,
    nlayers=4,
    layers_cls=[],
    labels=datamodule.labels,
    cls_hierarchy=datamodule.cls_hierarchy,
    dropout=0.1,
    transformer="flash",
    precpt_gene_emb=gene_emb,
    gene_pos_enc=datamodule.gene_pos,
    mvc_decoder="inner product",
    label_decoders=datamodule.decoders,
    fused_dropout_add_ln=False,
)
scPrint(
(gene_encoder): GeneEncoder(
(embedding): Embedding(33890, 256)
(enc_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(expr_encoder): ContinuousValueEncoder(
(encoder): ModuleList(
(0): Linear(in_features=1, out_features=256, bias=True)
(1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(2): ReLU()
(3): Dropout(p=0.1, inplace=False)
)
)
(pos_encoder): PositionalEncoding(
(dropout): Dropout(p=0.1, inplace=False)
)
(label_encoder): CategoryValueEncoder(
(embedding): Embedding(8, 256)
(enc_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)
(depth_decoder): ContinuousValueEncoder(
(encoder): ModuleList(
(0): Linear(in_features=1, out_features=256, bias=True)
(1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(2): ReLU()
(3): Dropout(p=0.1, inplace=False)
)
)
(transformer): FlashTransformerEncoder(
(blocks): ModuleList(
(0-3): 4 x Block(
(mixer): MHA(
(Wqkv): Linear(in_features=256, out_features=768, bias=True)
(inner_attn): FlashSelfAttention()
(inner_cross_attn): FlashCrossAttention(
(drop): Dropout(p=0.1, inplace=False)
)
(out_proj): Linear(in_features=256, out_features=256, bias=True)
)
(dropout1): Dropout(p=0.1, inplace=False)
(drop_path1): StochasticDepth(p=0.0, mode=row)
(norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
(mlp): Mlp(
(fc1): Linear(in_features=256, out_features=1024, bias=True)
(activation): GELU(approximate='none')
(fc2): Linear(in_features=1024, out_features=256, bias=True)
)
(dropout2): Dropout(p=0.1, inplace=False)
(drop_path2): StochasticDepth(p=0.0, mode=row)
(norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
)
)
(dropout): Dropout(p=0.1, inplace=False)
(drop_path): StochasticDepth(p=0.0, mode=row)
(norm): LayerNorm((256,), eps=1e-06, elementwise_affine=True)
)
(expr_decoder): ExprDecoder(
(fc): Sequential(
(0): Linear(in_features=256, out_features=256, bias=True)
(1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(2): LeakyReLU(negative_slope=0.01)
(3): Dropout(p=0.1, inplace=False)
(4): Linear(in_features=256, out_features=256, bias=True)
(5): LeakyReLU(negative_slope=0.01)
)
(pred_var_zero): Linear(in_features=256, out_features=3, bias=True)
)
(cls_decoders): ModuleDict(
(cell_type_ontology_term_id): ClsDecoder(
(decoder): Sequential()
(out_layer): Linear(in_features=256, out_features=190, bias=True)
)
(disease_ontology_term_id): ClsDecoder(
(decoder): Sequential()
(out_layer): Linear(in_features=256, out_features=18, bias=True)
)
(assay_ontology_term_id): ClsDecoder(
(decoder): Sequential()
(out_layer): Linear(in_features=256, out_features=11, bias=True)
)
(self_reported_ethnicity_ontology_term_id): ClsDecoder(
(decoder): Sequential()
(out_layer): Linear(in_features=256, out_features=6, bias=True)
)
(sex_ontology_term_id): ClsDecoder(
(decoder): Sequential()
(out_layer): Linear(in_features=256, out_features=2, bias=True)
)
(organism_ontology_term_id): ClsDecoder(
(decoder): Sequential()
(out_layer): Linear(in_features=256, out_features=2, bias=True)
)
)
(mvc_decoder): MVCDecoder(
(gene2query): Linear(in_features=256, out_features=256, bias=True)
(query_activation): Sigmoid()
(pred_var_zero): Linear(in_features=256, out_features=768, bias=False)
)
)
In [ ]:
Copied!
# create a function to transform an scGPT checkpoint to an scPrint's
# ckpt = torch.load("../../scGPT/save/model_e6.pt")
# scPrint.load_from_checkpoint("../../scGPT/save/model_e6.pt")
# create a function to transform an scGPT checkpoint to an scPrint's
# ckpt = torch.load("../../scGPT/save/model_e6.pt")
# scPrint.load_from_checkpoint("../../scGPT/save/model_e6.pt")
In [ ]:
Copied!
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.loggers import WandbLogger

# Create the Weights & Biases logger once.
wandb_logger = WandbLogger(project="scprint_test", save_dir="../data/tensorboard")
# Register the model a single time: watching the same model instance twice is
# redundant, and wandb is expected to reject a repeated watch -- confirm in the wandb docs.
wandb_logger.watch(model, log='all', log_freq=50, log_graph=True)
#tlogger = TensorBoardLogger(save_dir="../data/tensorboard")
#tlogger.log_graph(model, i)
2024-02-19 16:41:43,615:ERROR - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving. wandb: Currently logged in as: jkobject (ml4ig). Use `wandb login --relogin` to force relogin
wandb version 0.16.3 is available! To upgrade, please run:
$ pip install wandb --upgrade
Tracking run with wandb version 0.16.2
Run data is saved locally in
../data/tensorboard/wandb/run-20240219_164145-iuealg88
View project at https://wandb.ai/ml4ig/scprint_test
wandb: logging graph, to disable use `wandb.watch(log_graph=False)`
In [ ]:
Copied!
#from lightning.pytorch.profilers import PyTorchProfiler
#pytorch_prof = PyTorchProfiler("../data/tensorboard", emit_nvtx=False, group_by_input_shape=True, record_shapes=True, profile_memory=True, with_stack=True, on_trace_ready=torch.profiler.tensorboard_trace_handler("../data/tensorboard/"),)
#from lightning.pytorch.profilers import PyTorchProfiler
#pytorch_prof = PyTorchProfiler("../data/tensorboard", emit_nvtx=False, group_by_input_shape=True, record_shapes=True, profile_memory=True, with_stack=True, on_trace_ready=torch.profiler.tensorboard_trace_handler("../data/tensorboard/"),)
In [14]:
Copied!
# Callbacks and Trainer for the main training run.

# save_top_k=-1: keep every checkpoint (see the Lightning ModelCheckpoint docs).
chckp = ModelCheckpoint(monitor="val_loss", save_top_k=-1)
trainingmode = TrainingMode(
    do_denoise=True,
    noise=[0.3],
    do_cce=True,
    cce_sim=0.5,
    do_ecs=True,
    ecs_threshold=0.3,
    ecs_scale=10.0,
    do_mvc=False,
    do_adv_cls=False,
    do_next_tp=False,
    class_scale=5000.0,
    mask_ratio=[0.15, 0.3],
    warmup_duration=500,
    weight_decay=0.1,
    fused_adam=True,
    lr_patience=1,
)
es = EarlyStopping(patience=2, monitor='val_loss')
# NOTE: `swa` is created but not included in the Trainer callbacks below.
swa = StochasticWeightAveraging(swa_lrs=0.01)
lrm = LearningRateMonitor(logging_interval="step")
#lrf = LearningRateFinder(mode="exponential",)
# TODO: check that the class hierarchies are really ordered 1-2-3-4... as well (ordered dict)
trainer = Trainer(
    precision="16-mixed",
    gradient_clip_val=10,
    max_time={"hours": 3},
    limit_train_batches=5000,
    limit_test_batches=0.03,
    limit_val_batches=1000,
    callbacks=[chckp, trainingmode, es, lrm],
    accumulate_grad_batches=2,
    reload_dataloaders_every_n_epochs=1,
)
# Debug options kept for reference:
#detect_anomaly=True, fast_dev_run=20, overfit_batches=10, limit_train_batches=1, limit_val_batches=0
#logger=wandb_logger,
Using 16bit Automatic Mixed Precision (AMP) GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs
In [44]:
Copied!
# Display the model's `labels` attribute as the cell output.
model.labels
model.labels
Out[44]:
['cell_type_ontology_term_id', 'disease_ontology_term_id', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id']
In [13]:
Copied!
# Sanity check: trained on a single batch (overfit_batches=1) the model should overfit.
# Validation is disabled (limit_val_batches=0) for this run.
trainer = Trainer(
    precision="16-mixed",
    max_epochs=1000,
    limit_val_batches=0,
    check_val_every_n_epoch=1000,
    log_every_n_steps=1000,
    detect_anomaly=False,
    overfit_batches=1,
    reload_dataloaders_every_n_epochs=1000,
)  #logger=wandb_logger) limit_train_batches=1
# Sanity check: trained on a single batch (overfit_batches=1) the model should overfit.
# Validation is disabled (limit_val_batches=0) for this run.
trainer = Trainer(
    precision="16-mixed",
    max_epochs=1000,
    limit_val_batches=0,
    check_val_every_n_epoch=1000,
    log_every_n_steps=1000,
    detect_anomaly=False,
    overfit_batches=1,
    reload_dataloaders_every_n_epochs=1000,
)  #logger=wandb_logger) limit_train_batches=1
Using 16bit Automatic Mixed Precision (AMP) GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs `Trainer(overfit_batches=1)` was configured so 1 batch will be used. TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs `Trainer(overfit_batches=1)` was configured so 1 batch will be used.
In [15]:
Copied!
# Start training `model` on `datamodule` with whichever `trainer` was created last in the kernel.
# NOTE(review): two different `trainer` objects are defined in earlier cells — confirm which one is intended here.
trainer.fit(model, datamodule=datamodule)
trainer.fit(model, datamodule=datamodule)
these files will be considered test datasets:
/home/ml4ig1/scprint/.lamindb/BljRloq1xjcxRNDpejzI.h5ad
/home/ml4ig1/scprint/.lamindb/yBCKp6HmXuHa0cZptMo7.h5ad
perc test: 0.0057480725242011555
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params ---------------------------------------------------------- 0 | gene_encoder | GeneEncoder | 8.7 M 1 | expr_encoder | ContinuousValueEncoder | 1.0 K 2 | pos_encoder | PositionalEncoding | 0 3 | label_encoder | CategoryValueEncoder | 2.6 K 4 | depth_decoder | ContinuousValueEncoder | 1.0 K 5 | transformer | FlashTransformerEncoder | 3.2 M 6 | expr_decoder | ExprDecoder | 132 K 7 | cls_decoders | ModuleDict | 58.9 K 8 | mvc_decoder | MVCDecoder | 262 K ---------------------------------------------------------- 3.6 M Trainable params 8.7 M Non-trainable params 12.3 M Total params 49.179 Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s]
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/anndata/_core/anndata.py:183: ImplicitModificationWarning: Transforming to str index.
warnings.warn("Transforming to str index.", ImplicitModificationWarning)
WARNING: You’re trying to run this on 256 dimensions of `.X`, if you really want this, set `use_rep='X'`.
Falling back to preprocessing with `sc.pp.pca` and default params.
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/anndata/_core/anndata.py:522: FutureWarning: The dtype argument is deprecated and will be removed in late 2024. warnings.warn(
AnnData object with n_obs × n_vars = 106 × 256
obs: 'pred_cell_type_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'cell_type_ontology_term_id', 'disease_ontology_term_id', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'conv_cell_type_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_assay_ontology_term_id', 'conv_pred_assay_ontology_term_id', 'leiden'
uns: 'neighbors', 'umap', 'leiden'
obsm: 'X_pca', 'X_umap'
obsp: 'distances', 'connectivities'
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. 
Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. 
Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. 
Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter(
couldn't log to tensorboard couldn't log to wandb
Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/anndata/_core/anndata.py:183: ImplicitModificationWarning: Transforming to str index.
warnings.warn("Transforming to str index.", ImplicitModificationWarning)
WARNING: You’re trying to run this on 256 dimensions of `.X`, if you really want this, set `use_rep='X'`.
Falling back to preprocessing with `sc.pp.pca` and default params.
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/anndata/_core/anndata.py:522: FutureWarning: The dtype argument is deprecated and will be removed in late 2024. warnings.warn(
AnnData object with n_obs × n_vars = 10016 × 256
obs: 'pred_cell_type_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'cell_type_ontology_term_id', 'disease_ontology_term_id', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'conv_cell_type_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_assay_ontology_term_id', 'conv_pred_assay_ontology_term_id', 'leiden'
uns: 'neighbors', 'umap', 'leiden'
obsm: 'X_pca', 'X_umap'
obsp: 'distances', 'connectivities'
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. 
Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. 
Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. 
Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter(
couldn't log to tensorboard couldn't log to wandb
Validation: 0it [00:00, ?it/s]
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/anndata/_core/anndata.py:183: ImplicitModificationWarning: Transforming to str index.
warnings.warn("Transforming to str index.", ImplicitModificationWarning)
WARNING: You’re trying to run this on 256 dimensions of `.X`, if you really want this, set `use_rep='X'`.
Falling back to preprocessing with `sc.pp.pca` and default params.
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/anndata/_core/anndata.py:522: FutureWarning: The dtype argument is deprecated and will be removed in late 2024. warnings.warn(
AnnData object with n_obs × n_vars = 10007 × 256
obs: 'pred_cell_type_ontology_term_id', 'pred_disease_ontology_term_id', 'pred_assay_ontology_term_id', 'pred_self_reported_ethnicity_ontology_term_id', 'pred_sex_ontology_term_id', 'pred_organism_ontology_term_id', 'cell_type_ontology_term_id', 'disease_ontology_term_id', 'assay_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'sex_ontology_term_id', 'organism_ontology_term_id', 'conv_cell_type_ontology_term_id', 'conv_pred_cell_type_ontology_term_id', 'conv_assay_ontology_term_id', 'conv_pred_assay_ontology_term_id', 'leiden'
uns: 'neighbors', 'umap', 'leiden'
obsm: 'X_pca', 'X_umap'
obsp: 'distances', 'connectivities'
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. 
Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. 
Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. 
Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter( /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:1251: FutureWarning: The default value of 'ignore' for the `na_action` parameter in pandas.Categorical.map is deprecated and will be changed to 'None' in a future version. Please set na_action to the desired value to avoid seeing this warning color_vector = pd.Categorical(values.map(color_map)) /home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored cax = scatter(
couldn't log to tensorboard couldn't log to wandb
/home/ml4ig1/miniconda3/envs/scprint/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:53: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
In [ ]:
Copied!
# (Re)load the TensorBoard notebook extension and open it on the logs in ../data/tensorboard.
%reload_ext tensorboard
%tensorboard --logdir="../data/tensorboard"
%reload_ext tensorboard
%tensorboard --logdir="../data/tensorboard"
In [16]:
Copied!
#wandb_logger.finalize(status="aborted")
torch.cuda.empty_cache()
# Disabled call; presumably kept to manually mark the W&B run as aborted after
# an interrupted training -- TODO confirm it is still needed.
#wandb_logger.finalize(status="aborted")
# Ask PyTorch to release the unused GPU memory it is still caching.
torch.cuda.empty_cache()
In [ ]:
Copied!
# Sum the softmaxed attention weights of transformer layer `layer_num` over
# every cell served by the test dataloader. The result, `sm_attn_scores`, is a
# NumPy array indexed as [n_heads, gene, gene].
#
# NOTE(review): `all_gene_ids`, `all_values`, `src_key_padding_mask` and
# `layer_num` are not defined in this cell and must already exist in the
# kernel. `batch` is only used for its size, so the rows of those arrays must
# follow the same order as `datamodule.test_dataloader()` -- TODO confirm.
torch.cuda.empty_cache()
dict_sum_condition = {}
model.eval()

i = 0  # offset of the current batch inside the pre-computed arrays
sm_attn_scores = None  # running sum, created when the first batch is processed
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
    device = next(model.parameters()).device
    for batch in tqdm(datamodule.test_dataloader()):
        batch_size = batch['x'].size(0)
        # Replicate the operations in model forward pass
        src_embs = model._encoder(
            torch.tensor(all_gene_ids[i : i + batch_size], dtype=torch.long).to(device)
        )
        val_embs = model.value_encoder(
            torch.tensor(all_values[i : i + batch_size], dtype=torch.float).to(device)
        )
        total_embs = src_embs + val_embs
        # total_embs = model.layer(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
        # Send total_embs to attention layers for attention operations
        # Retrieve the output from second to last layer
        for layer in model.transformer_encoder.layers[:layer_num]:
            total_embs = layer(
                total_embs,
                src_key_padding_mask=src_key_padding_mask[i : i + batch_size].to(
                    device
                ),
            )
        # Send total_embs to the last layer in flash-attn
        # https://github.com/HazyResearch/flash-attention/blob/1b18f1b7a133c20904c096b8b222a0916e1b3d37/flash_attn/flash_attention.py#L90
        qkv = model.transformer_encoder.layers[layer_num].self_attn.Wqkv(
            total_embs
        )
        # Retrieve q and k from the flash-attn wrapper (v is not needed to
        # compute the attention weights).
        # NOTE(review): h=8 hard-codes the number of heads -- confirm it matches
        # the model configuration.
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=8)
        q = qkv[:, :, 0, :, :]
        k = qkv[:, :, 1, :, :]
        # https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a
        # q = [batch, gene, n_heads, n_hid]
        # k = [batch, gene, n_heads, n_hid]
        # attn_scores = [batch, n_heads, gene, gene]
        attn_scores = q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)
        # apply softmax to get attention weights
        # NOTE(review): the scores are not scaled by 1/sqrt(n_hid) before the
        # softmax -- confirm this is intended.
        attn_scores = softmax(attn_scores, dim=-1)
        # Sum over the batch dimension and accumulate on the CPU.
        batch_sum = attn_scores.sum(0).detach().cpu().numpy()
        if sm_attn_scores is None:
            sm_attn_scores = batch_sum
        else:
            # take the sum
            sm_attn_scores += batch_sum
        # Advance the offset so the next batch reads the next slice.
        i += batch_size

# A bare `return` is invalid at cell level; expose the result as the cell output.
sm_attn_scores
# Sum the softmaxed attention weights of transformer layer `layer_num` over
# every cell served by the test dataloader. The result, `sm_attn_scores`, is a
# NumPy array indexed as [n_heads, gene, gene].
#
# NOTE(review): `all_gene_ids`, `all_values`, `src_key_padding_mask` and
# `layer_num` are not defined in this cell and must already exist in the
# kernel. `batch` is only used for its size, so the rows of those arrays must
# follow the same order as `datamodule.test_dataloader()` -- TODO confirm.
torch.cuda.empty_cache()
dict_sum_condition = {}
model.eval()

i = 0  # offset of the current batch inside the pre-computed arrays
sm_attn_scores = None  # running sum, created when the first batch is processed
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
    device = next(model.parameters()).device
    for batch in tqdm(datamodule.test_dataloader()):
        batch_size = batch['x'].size(0)
        # Replicate the operations in model forward pass
        src_embs = model._encoder(
            torch.tensor(all_gene_ids[i : i + batch_size], dtype=torch.long).to(device)
        )
        val_embs = model.value_encoder(
            torch.tensor(all_values[i : i + batch_size], dtype=torch.float).to(device)
        )
        total_embs = src_embs + val_embs
        # total_embs = model.layer(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
        # Send total_embs to attention layers for attention operations
        # Retrieve the output from second to last layer
        for layer in model.transformer_encoder.layers[:layer_num]:
            total_embs = layer(
                total_embs,
                src_key_padding_mask=src_key_padding_mask[i : i + batch_size].to(
                    device
                ),
            )
        # Send total_embs to the last layer in flash-attn
        # https://github.com/HazyResearch/flash-attention/blob/1b18f1b7a133c20904c096b8b222a0916e1b3d37/flash_attn/flash_attention.py#L90
        qkv = model.transformer_encoder.layers[layer_num].self_attn.Wqkv(
            total_embs
        )
        # Retrieve q and k from the flash-attn wrapper (v is not needed to
        # compute the attention weights).
        # NOTE(review): h=8 hard-codes the number of heads -- confirm it matches
        # the model configuration.
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=8)
        q = qkv[:, :, 0, :, :]
        k = qkv[:, :, 1, :, :]
        # https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a
        # q = [batch, gene, n_heads, n_hid]
        # k = [batch, gene, n_heads, n_hid]
        # attn_scores = [batch, n_heads, gene, gene]
        attn_scores = q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)
        # apply softmax to get attention weights
        # NOTE(review): the scores are not scaled by 1/sqrt(n_hid) before the
        # softmax -- confirm this is intended.
        attn_scores = softmax(attn_scores, dim=-1)
        # Sum over the batch dimension and accumulate on the CPU.
        batch_sum = attn_scores.sum(0).detach().cpu().numpy()
        if sm_attn_scores is None:
            sm_attn_scores = batch_sum
        else:
            # take the sum
            sm_attn_scores += batch_sum
        # Advance the offset so the next batch reads the next slice.
        i += batch_size

# A bare `return` is invalid at cell level; expose the result as the cell output.
sm_attn_scores
In [ ]:
Copied!
----
# TODO: connect with maestro people to ask for longer compute time
# TODO: do the same to jean zay (0.5 day)
------
# TODO: make a model benchmark package (continue from where I left off) (4 days)
# TODO: make a task function & make a benchmark function (1 day) (*denoising, *perturbation prediction)
------
# TODO: debug the gene embedding creation
# TODO: create embedding & make it work for the 4-5 species in the dataset (1 day)
# TODO: find the neighbors and next time point cells (1 day)
# TODO: create a version with next time point and neighbors task (1 day)
# TODO: make a trajectory prediction task (predict future cell type/s, expression) and benchmark (similarity to known future cell, similarity to known future expression) (1 day)
------
# TODO: run a large training on maestro (0.5 day)
------
# TODO: add KO & drug datasets
# TODO: create a version with KO and drug effect prediction
----
# TODO: connect with maestro people to ask for longer compute time
# TODO: do the same to jean zay (0.5 day)
------
# TODO: make a model benchmark package (continue from where I left off) (4 days)
# TODO: make a task function & make a benchmark function (1 day) (*denoising, *perturbation prediction)
------
# TODO: debug the gene embedding creation
# TODO: create embedding & make it work for the 4-5 species in the dataset (1 day)
# TODO: find the neighbors and next time point cells (1 day)
# TODO: create a version with next time point and neighbors task (1 day)
# TODO: make a trajectory prediction task (predict future cell type/s, expression) and benchmark (similarity to known future cell, similarity to known future expression) (1 day)
------
# TODO: run a large training on maestro (0.5 day)
------
# TODO: add KO & drug datasets
# TODO: create a version with KO and drug effect prediction